/*

 * Licensed to the Apache Software Foundation (ASF) under one

 * or more contributor license agreements.  See the NOTICE file

 * distributed with this work for additional information

 * regarding copyright ownership.  The ASF licenses this file

 * to you under the Apache License, Version 2.0 (the

 * "License"); you may not use this file except in compliance

 * with the License.  You may obtain a copy of the License at

 *

 *     http://www.apache.org/licenses/LICENSE-2.0

 *

 * Unless required by applicable law or agreed to in writing, software

 * distributed under the License is distributed on an "AS IS" BASIS,

 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

 * See the License for the specific language governing permissions and

 * limitations under the License.

 */

package com.bff.gaia.unified.fn.harness;



import com.bff.gaia.unified.fn.harness.control.BundleSplitListener;

import com.bff.gaia.unified.fn.harness.data.UnifiedFnDataClient;

import com.bff.gaia.unified.fn.harness.data.PCollectionConsumerRegistry;

import com.bff.gaia.unified.fn.harness.data.PTransformFunctionRegistry;

import com.bff.gaia.unified.fn.harness.state.UnifiedFnStateClient;

import com.bff.gaia.unified.fn.harness.state.SideInputSpec;

import com.bff.gaia.unified.model.pipeline.v1.RunnerApi;

import com.bff.gaia.unified.model.pipeline.v1.RunnerApi.PCollection;

import com.bff.gaia.unified.model.pipeline.v1.RunnerApi.PTransform;

import com.bff.gaia.unified.model.pipeline.v1.RunnerApi.ParDoPayload;

import com.bff.gaia.unified.runners.core.construction.PCollectionViewTranslation;

import com.bff.gaia.unified.runners.core.construction.ParDoTranslation;

import com.bff.gaia.unified.runners.core.construction.RehydratedComponents;

import com.bff.gaia.unified.runners.core.construction.Timer;

import com.bff.gaia.unified.sdk.Pipeline;

import com.bff.gaia.unified.sdk.coders.Coder;

import com.bff.gaia.unified.sdk.coders.KvCoder;

import com.bff.gaia.unified.sdk.fn.data.FnDataReceiver;

import com.bff.gaia.unified.sdk.options.PipelineOptions;

import com.bff.gaia.unified.sdk.schemas.SchemaCoder;

import com.bff.gaia.unified.sdk.state.TimeDomain;

import com.bff.gaia.unified.sdk.transforms.DoFn;

import com.bff.gaia.unified.sdk.transforms.Materializations;

import com.bff.gaia.unified.sdk.transforms.reflect.DoFnSignature;

import com.bff.gaia.unified.sdk.transforms.reflect.DoFnSignatures;

import com.bff.gaia.unified.sdk.transforms.windowing.BoundedWindow;

import com.bff.gaia.unified.sdk.util.WindowedValue;

import com.bff.gaia.unified.sdk.util.WindowedValue.WindowedValueCoder;

import com.bff.gaia.unified.sdk.values.KV;

import com.bff.gaia.unified.sdk.values.TupleTag;

import com.bff.gaia.unified.sdk.values.WindowingStrategy;

import com.bff.gaia.unified.vendor.guava.com.google.common.collect.*;



import java.io.IOException;

import java.util.Map;

import java.util.function.Supplier;



import static com.bff.gaia.unified.vendor.guava.com.google.common.base.Preconditions.checkArgument;



/** A {@link PTransformRunnerFactory} for transforms invoking a {@link DoFn}. */

abstract class DoFnPTransformRunnerFactory<

        TransformInputT,

        FnInputT,

        OutputT,

        RunnerT extends DoFnPTransformRunnerFactory.DoFnPTransformRunner<TransformInputT>>

    implements PTransformRunnerFactory<RunnerT> {

  /**
   * Callbacks a {@link DoFn}-backed runner exposes so that {@code createRunnerForPTransform} can
   * wire it into the bundle lifecycle and the PCollection consumer registry.
   *
   * @param <T> element type carried by the main-input {@link WindowedValue}s
   */
  interface DoFnPTransformRunner<T> {
    /** Registered with the start-function registry under the PTransform id. */
    void startBundle() throws Exception;

    /** Registered as the consumer of the main-input PCollection. */
    void processElement(WindowedValue<T> input) throws Exception;

    /**
     * Registered as the consumer of each timer PCollection.
     *
     * @param timerId local input name of the timer, taken from the ParDo payload's timer specs
     * @param timeDomain time domain read from the DoFn's timer spec for {@code timerId}
     * @param input the timer value as delivered on the timer PCollection
     */
    void processTimer(
        String timerId, TimeDomain timeDomain, WindowedValue<KV<Object, Timer>> input);

    /** Registered with the finish-function registry under the PTransform id. */
    void finishBundle() throws Exception;
  }



  /**
   * Decodes the ParDo described by {@code pTransform} into a {@link Context}, obtains a runner
   * from {@link #createRunner}, and registers that runner's callbacks.
   *
   * <p>Registration happens in this order: the start-bundle function, the element consumer for
   * the main input, one timer consumer per entry in the payload's timer specs, and finally the
   * finish-bundle function. The main input is every input of the transform that is neither a
   * side input nor a timer input.
   *
   * <p>{@code unifiedFnDataClient} is part of the factory signature but is not used here.
   *
   * @return the runner that was created and registered
   */
  @Override
  public final RunnerT createRunnerForPTransform(
      PipelineOptions pipelineOptions,
      UnifiedFnDataClient unifiedFnDataClient,
      UnifiedFnStateClient unifiedFnStateClient,
      String pTransformId,
      PTransform pTransform,
      Supplier<String> processBundleInstructionId,
      Map<String, PCollection> pCollections,
      Map<String, RunnerApi.Coder> coders,
      Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
      PCollectionConsumerRegistry pCollectionConsumerRegistry,
      PTransformFunctionRegistry startFunctionRegistry,
      PTransformFunctionRegistry finishFunctionRegistry,
      BundleSplitListener splitListener) {
    Context<FnInputT, OutputT> parDoContext =
        new Context<>(
            pipelineOptions,
            unifiedFnStateClient,
            pTransformId,
            pTransform,
            processBundleInstructionId,
            pCollections,
            coders,
            windowingStrategies,
            pCollectionConsumerRegistry,
            splitListener);
    RunnerT runner = createRunner(parDoContext);
    ParDoPayload payload = parDoContext.parDoPayload;

    // Bundle start hook.
    startFunctionRegistry.register(pTransformId, runner::startBundle);

    // Element consumer: inputs that are neither side inputs nor timers.
    for (String mainInputName :
        Sets.difference(
            pTransform.getInputsMap().keySet(),
            Sets.union(
                payload.getSideInputsMap().keySet(), payload.getTimerSpecsMap().keySet()))) {
      FnDataReceiver<WindowedValue<TransformInputT>> elementReceiver = runner::processElement;
      pCollectionConsumerRegistry.register(
          pTransform.getInputsOrThrow(mainInputName),
          pTransformId,
          (FnDataReceiver) elementReceiver);
    }

    // Timer consumers: one per timer spec, each bound to its own id and time domain.
    for (String timerName : payload.getTimerSpecsMap().keySet()) {
      TimeDomain timerDomain =
          DoFnSignatures.getTimerSpecOrThrow(
                  parDoContext.doFnSignature.timerDeclarations().get(timerName),
                  parDoContext.doFn)
              .getTimeDomain();
      String timerPCollectionId = pTransform.getInputsOrThrow(timerName);
      pCollectionConsumerRegistry.register(
          timerPCollectionId,
          pTransformId,
          (FnDataReceiver)
              firedTimer -> {
                WindowedValue<KV<Object, Timer>> timerValue =
                    (WindowedValue<KV<Object, Timer>>) firedTimer;
                runner.processTimer(timerName, timerDomain, timerValue);
              });
    }

    // Bundle finish hook.
    finishFunctionRegistry.register(pTransformId, runner::finishBundle);
    return runner;
  }



  /**
   * Creates the runner for a single ParDo from its decoded {@link Context}.
   *
   * <p>Called once per {@code createRunnerForPTransform} invocation, after the context has been
   * built and before any callback of the returned runner is registered.
   *
   * @param context decoded description of the PTransform being executed
   * @return the runner whose start/process/timer/finish callbacks will be registered
   */
  abstract RunnerT createRunner(Context<FnInputT, OutputT> context);



  /**
   * Decoded description of one ParDo {@link PTransform}: the constructor arguments that are kept
   * as-is plus everything rehydrated from the transform's {@link ParDoPayload}.
   *
   * <ul>
   *   <li>{@link #rehydratedComponents} is built from the supplied coder, PCollection and
   *       windowing strategy maps.
   *   <li>{@link #doFn}, {@link #doFnSignature} and {@link #mainOutputTag} come from the payload.
   *   <li>{@link #inputCoder} is the main input PCollection's coder exactly as registered (it may
   *       still be a {@link WindowedValueCoder}).
   *   <li>{@link #keyCoder} is the key coder when the main input's value coder is a {@link
   *       KvCoder}, otherwise {@code null}.
   *   <li>{@link #schemaCoder} is the main input's value coder when it is a {@link SchemaCoder},
   *       otherwise {@code null}.
   *   <li>{@link #mainOutputSchemaCoder} is the main output's value coder when it is a {@link
   *       SchemaCoder}, otherwise {@code null}.
   *   <li>{@link #windowingStrategy} and {@link #windowCoder} come from the main input
   *       PCollection.
   *   <li>{@link #outputCoders} maps each output tag to its value coder (windowed-value wrapper
   *       removed).
   *   <li>{@link #tagToSideInputSpecMap} holds one {@link SideInputSpec} per side input.
   *   <li>{@link #localNameToConsumer} maps each output local name to the multiplexing consumer
   *       of the corresponding output PCollection.
   * </ul>
   *
   * @param <InputT> input element type of the {@link DoFn}
   * @param <OutputT> main output element type of the {@link DoFn}
   */
  static class Context<InputT, OutputT> {
    final PipelineOptions pipelineOptions;
    final UnifiedFnStateClient unifiedFnStateClient;
    final String ptransformId;
    final PTransform pTransform;
    final Supplier<String> processBundleInstructionId;
    final RehydratedComponents rehydratedComponents;
    final DoFn<InputT, OutputT> doFn;
    final DoFnSignature doFnSignature;
    final TupleTag<OutputT> mainOutputTag;
    final Coder<?> inputCoder;
    final SchemaCoder<InputT> schemaCoder;
    final Coder<?> keyCoder;
    final SchemaCoder<OutputT> mainOutputSchemaCoder;
    final Coder<? extends BoundedWindow> windowCoder;
    final WindowingStrategy<InputT, ?> windowingStrategy;
    final Map<TupleTag<?>, SideInputSpec> tagToSideInputSpecMap;
    // NOTE(review): the only non-final field, backed by a mutable HashMap. Left as declared
    // because code outside this class may reassign or mutate it.
    Map<TupleTag<?>, Coder<?>> outputCoders;
    final ParDoPayload parDoPayload;
    final ListMultimap<String, FnDataReceiver<WindowedValue<?>>> localNameToConsumer;
    final BundleSplitListener splitListener;

    /**
     * Rehydrates the ParDo described by {@code pTransform}.
     *
     * <p>The transform must have exactly one input that is neither a side input nor a timer
     * input; {@code Iterables.getOnlyElement} rejects anything else.
     *
     * @throws IllegalArgumentException if decoding the payload or any coder / windowing strategy
     *     raises an {@link IOException}, if a side input uses a materialization other than
     *     {@link Materializations#MULTIMAP_MATERIALIZATION_URN}, or if the transform refers to a
     *     PCollection id that is missing from {@code pCollections}
     */
    Context(
        PipelineOptions pipelineOptions,
        UnifiedFnStateClient unifiedFnStateClient,
        String ptransformId,
        PTransform pTransform,
        Supplier<String> processBundleInstructionId,
        Map<String, PCollection> pCollections,
        Map<String, RunnerApi.Coder> coders,
        Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
        PCollectionConsumerRegistry pCollectionConsumerRegistry,
        BundleSplitListener splitListener) {
      this.pipelineOptions = pipelineOptions;
      this.unifiedFnStateClient = unifiedFnStateClient;
      this.ptransformId = ptransformId;
      this.pTransform = pTransform;
      this.processBundleInstructionId = processBundleInstructionId;
      ImmutableMap.Builder<TupleTag<?>, SideInputSpec> tagToSideInputSpecMapBuilder =
          ImmutableMap.builder();
      try {
        rehydratedComponents =
            RehydratedComponents.forComponents(
                    RunnerApi.Components.newBuilder()
                        .putAllCoders(coders)
                        .putAllPcollections(pCollections)
                        .putAllWindowingStrategies(windowingStrategies)
                        .build())
                .withPipeline(Pipeline.create());
        parDoPayload = ParDoPayload.parseFrom(pTransform.getSpec().getPayload());
        doFn = (DoFn) ParDoTranslation.getDoFn(parDoPayload);
        doFnSignature = DoFnSignatures.signatureForDoFn(doFn);
        mainOutputTag = (TupleTag) ParDoTranslation.getMainOutputTag(parDoPayload);

        // The main input is the single input that is neither a side input nor a timer.
        String mainInputTag =
            Iterables.getOnlyElement(
                Sets.difference(
                    pTransform.getInputsMap().keySet(),
                    Sets.union(
                        parDoPayload.getSideInputsMap().keySet(),
                        parDoPayload.getTimerSpecsMap().keySet())));
        PCollection mainInput =
            lookupPCollection(
                pCollections,
                pTransform.getInputsOrThrow(mainInputTag),
                ptransformId,
                mainInputTag);
        inputCoder = rehydratedComponents.getCoder(mainInput.getCoderId());

        // TODO: Stop passing windowed value coders within PCollections.
        // Key and schema coders are derived from the value coder, whether or not the registered
        // coder is wrapped in a WindowedValueCoder.
        Coder<?> inputValueCoder = unwrapWindowedValueCoder(inputCoder);
        this.keyCoder =
            inputValueCoder instanceof KvCoder
                ? ((KvCoder<?, ?>) inputValueCoder).getKeyCoder()
                : null;
        this.schemaCoder =
            inputValueCoder instanceof SchemaCoder ? (SchemaCoder<InputT>) inputValueCoder : null;

        windowingStrategy =
            (WindowingStrategy)
                rehydratedComponents.getWindowingStrategy(mainInput.getWindowingStrategyId());
        windowCoder = windowingStrategy.getWindowFn().windowCoder();

        outputCoders = Maps.newHashMap();
        for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
          TupleTag<?> outputTag = new TupleTag<>(entry.getKey());
          PCollection outputPCollection =
              lookupPCollection(pCollections, entry.getValue(), ptransformId, entry.getKey());
          outputCoders.put(
              outputTag,
              unwrapWindowedValueCoder(
                  rehydratedComponents.getCoder(outputPCollection.getCoderId())));
        }
        // Null when the main output tag has no entry in the transform's outputs.
        Coder<OutputT> mainOutputCoder = (Coder<OutputT>) outputCoders.get(mainOutputTag);
        mainOutputSchemaCoder =
            (mainOutputCoder instanceof SchemaCoder)
                ? (SchemaCoder<OutputT>) mainOutputCoder
                : null;

        // Build the map from tag id to side input specification
        for (Map.Entry<String, RunnerApi.SideInput> entry :
            parDoPayload.getSideInputsMap().entrySet()) {
          String sideInputTag = entry.getKey();
          RunnerApi.SideInput sideInput = entry.getValue();
          checkArgument(
              Materializations.MULTIMAP_MATERIALIZATION_URN.equals(
                  sideInput.getAccessPattern().getUrn()),
              "This SDK is only capable of dealing with %s materializations "
                  + "but was asked to handle %s for PCollectionView with tag %s.",
              Materializations.MULTIMAP_MATERIALIZATION_URN,
              sideInput.getAccessPattern().getUrn(),
              sideInputTag);

          PCollection sideInputPCollection =
              lookupPCollection(
                  pCollections,
                  pTransform.getInputsOrThrow(sideInputTag),
                  ptransformId,
                  sideInputTag);
          // Deliberately raw: SideInputSpec.create ties the window coder and the window mapping
          // function to the same type variable, which only resolves via unchecked conversion.
          WindowingStrategy sideInputWindowingStrategy =
              rehydratedComponents.getWindowingStrategy(
                  sideInputPCollection.getWindowingStrategyId());
          tagToSideInputSpecMapBuilder.put(
              new TupleTag<>(sideInputTag),
              SideInputSpec.create(
                  rehydratedComponents.getCoder(sideInputPCollection.getCoderId()),
                  sideInputWindowingStrategy.getWindowFn().windowCoder(),
                  PCollectionViewTranslation.viewFnFromProto(sideInput.getViewFn()),
                  PCollectionViewTranslation.windowMappingFnFromProto(
                      sideInput.getWindowMappingFn())));
        }
      } catch (IOException exn) {
        throw new IllegalArgumentException("Malformed ParDoPayload", exn);
      }

      ImmutableListMultimap.Builder<String, FnDataReceiver<WindowedValue<?>>>
          localNameToConsumerBuilder = ImmutableListMultimap.builder();
      for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
        localNameToConsumerBuilder.putAll(
            entry.getKey(), pCollectionConsumerRegistry.getMultiplexingConsumer(entry.getValue()));
      }
      localNameToConsumer = localNameToConsumerBuilder.build();
      tagToSideInputSpecMap = tagToSideInputSpecMapBuilder.build();
      this.splitListener = splitListener;
    }

    /**
     * Returns the value coder when {@code coder} is a {@link WindowedValueCoder}, otherwise
     * {@code coder} itself.
     */
    private static Coder<?> unwrapWindowedValueCoder(Coder<?> coder) {
      if (coder instanceof WindowedValueCoder) {
        return ((WindowedValueCoder<?>) coder).getValueCoder();
      }
      return coder;
    }

    /**
     * Looks up {@code pCollectionId}, failing with a descriptive message instead of letting a
     * later dereference raise a bare {@link NullPointerException}.
     *
     * @throws IllegalArgumentException if {@code pCollections} has no entry for the id
     */
    private static PCollection lookupPCollection(
        Map<String, PCollection> pCollections,
        String pCollectionId,
        String ptransformId,
        String localName) {
      PCollection pCollection = pCollections.get(pCollectionId);
      checkArgument(
          pCollection != null,
          "PTransform %s refers to unknown PCollection %s under local name %s.",
          ptransformId,
          pCollectionId,
          localName);
      return pCollection;
    }
  }

}